Python 中的 generator

一般用生成斐波那契数列做例子:

1
2
3
4
5
6
def fab(max): 
n, a, b = 0, 0, 1
while n < max:
print b
a, b = b, a + b
n = n + 1

这段代码打印斐波那契数列中的前Max项–问题在于只能使用 print 来输出,代码的复用性不高。

yield关键字可以用来将 print 语句需要输出的数列项“保存”起来,以便调用。用法就是将 print b 替换为 yield b,yield 关键字的作用就类似于“在此处保存”:

1
2
3
4
5
6
def fab(max): 
n, a, b = 0, 0, 1
while n < max:
yield b
a, b = b, a + b
n = n + 1

这样相当于保存了若干项符合规律的数列,但只保存了数列的“规律”,而不占用内存。
在需要产生项的时候,边循环边计算,这种方式叫做生成器(generator)。

简单地讲,yield 的作用就是把一个函数变成一个 generator,带有 yield 的函数不再是一个普通函数,Python 解释器会将其视为一个 generator,调用 fab(5) 不会执行 fab 函数,而是返回一个 iterable 对象!在 for 循环执行时,每次循环都会执行 fab 函数内部的代码,执行到 yield b 时,fab 函数就返回一个迭代值,下次迭代时,代码从 yield b 的下一条语句继续执行,而函数的本地变量看起来和上次中断执行前是完全一样的,于是函数继续执行,直到再次遇到 yield。

一、generator 的使用

以上段代码为例,调用 a = fab(5) 是,返回的 a 为一个 generator 而不是一个 list 。一个重要的区别就是,声明时 list 使用[],而 generator 使用()小括号。以廖雪峰老师的代码为例:

1
2
3
4
5
6
>>> L = [x * x for x in range(10)]
>>> L
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
>>> g = (x * x for x in range(10))
>>> g
<generator object <genexpr> at 0x1022ef630>

可以看出 generator 内的元素并不可以直接打印,可以通过 next() 函数来获得 generator 的下一个返回值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
>>> next(g)
0
>>> next(g)
1
>>> next(g)
4
>>> next(g)
9
>>> next(g)
16
>>> next(g)
25
>>> next(g)
36
>>> next(g)
49
>>> next(g)
64
>>> next(g)
81
>>> next(g)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
StopIteration

当然,最常用的应该是使用 for 循环:

1
2
3
4
5
6
7
8
9
>>> for n in g:
... print(n)
...
0
1
4
9
16
......

需要注意的是,python3.0之前的版本中,yield 相当于一个函数的返回值,因此不允许 return 具体值,会抛出SyntaxError: ‘return’ with argument inside generator异常。
3.0之后的版本允许在 yield 之前 return 一个值。

二、练习

使用 generator 打印杨辉三角

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def triangles(ls):     # triangle 函数生成杨辉三角的某一行
if ls == 1:
trianglerow = [1]
elif ls == 2:
trianglerow = [1,1]
else:
trianglerow = [1]
n = 0
while(n<ls-2):
trianglerow.append(triangles(ls-1)[n]+triangles(ls-1)[n+1])
n = n+1
trianglerow.append(1)
return trianglerow

def createtriangles(r): # createtriangles 函数 生成 generator
n = 1
while n <= r:
yield triangles(n)
n = n+1